import requests
import functools
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import os
import tiktoken
from datetime import datetime, timedelta
import subprocess
import zhipusdk
# NOTE(review): API key hard-coded in source and committed — it should be
# rotated and loaded from an environment variable or secret store instead.
zhipusdk.api_key = "872fd844f1c0b03ecf9792215fd9a9fd.Io7ymPlK0AxLcWf3"


from core import setting, mgClient
import pymongo as mg
# Mongo collections backing rate limiting and token accounting.
db = mgClient['chat']
# NOTE(review): "traiffic" looks like a typo for "traffic", but renaming the
# collections would orphan existing documents — left as-is.
traffic_collection = db.get_collection('traiffic')  # gpt-3.5 per-IP request log
traffic_collection_4 = db.get_collection('traiffic_4')  # gpt-4 per-IP request log
traffic_collection_poe = db.get_collection('traiffic_poe')  # poe/claude per-IP request log
user_collection = db.get_collection('user')  # per-user token balances

def add_user(user_key, token_num):
    """Create or overwrite a user's token balance.

    :param user_key: unique key identifying the user
    :param token_num: token balance to store for that user
    """
    # upsert=True turns the original race-prone find-then-insert/update pair
    # into a single atomic operation with the same end state.
    user_collection.update_one(
        {"user_key": user_key},
        {"$set": {"token_num": token_num}},
        upsert=True,
    )

def check_user_token(user_key):
    """Return the stored token balance for ``user_key``, or 0 when unknown.

    :param user_key: unique key identifying the user
    :return: the user's ``token_num``, or 0 if no such user exists
    """
    doc = user_collection.find_one({"user_key": user_key})
    return doc["token_num"] if doc else 0

def consume_token(user_key, token_num):
    """Subtract ``token_num`` from the user's balance.

    No-op when the user does not exist.  Like the original code, the
    balance is allowed to go negative.

    :param user_key: unique key identifying the user
    :param token_num: number of tokens to deduct
    """
    # $inc replaces the non-atomic read-modify-write, so concurrent requests
    # can no longer lose each other's deductions.
    user_collection.update_one(
        {"user_key": user_key},
        {"$inc": {"token_num": -token_num}},
    )

def generate_error_user():
    """Yield the single out-of-balance error message (for streaming replies)."""
    yield "没钱啦！"

def num_tokens_from_message(messages, model="gpt-3.5-turbo-0301"):
    """Estimate how many tokens ``messages`` consume for the given model.

    :param messages: list of message dicts (keys such as "role", "name",
        "content"; all values must be strings)
    :param model: model name, defaults to "gpt-3.5-turbo-0301"
    :return: estimated token count for the conversation
    :raises NotImplementedError: for any model other than gpt-3.5-turbo-0301
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # unknown model name: fall back to the base cl100k encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    if model != "gpt-3.5-turbo-0301":  # note: future models may deviate from this
        raise NotImplementedError(f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    total = 0
    for message in messages:
        # every message is framed as <im_start>{role/name}\n{content}<im_end>\n
        total += 4
        for field, value in message.items():
            total += len(encoding.encode(value))
            if field == "name":
                # when a name is present the role token is omitted
                total -= 1
    # every reply is primed with <im_start>assistant
    return total + 2

    
def get_api_key(api_keys, api_key_index):
    """Return the key at ``api_key_index`` plus the next round-robin index.

    :param api_keys: non-empty list of API keys
    :param api_key_index: current position in the rotation
    :return: tuple ``(api_key, next_index)`` — the caller must keep
        ``next_index`` for the following call
    """
    api_key = api_keys[api_key_index]
    # advance the cursor, wrapping back to the start of the list
    # (leftover debug print of the index removed)
    next_index = (api_key_index + 1) % len(api_keys)
    return api_key, next_index

def get_api_keys(key_file_path):
    """Load API keys from a file, one key per line.

    :param key_file_path: path to the file containing the API keys
    :return: list of the file's lines with surrounding whitespace stripped
    """
    with open(key_file_path, 'r') as handle:
        return [line.strip() for line in handle]

def check_traffic(ip):
    """Rate-limit an IP: True when it exceeded 400 requests in 24 hours.

    Each allowed call records one request in ``traffic_collection``.  The
    IP "221.215.48.6" is whitelisted and never throttled.

    :param ip: the client IP address to check
    :return: True when the caller is over the limit, False otherwise
    """
    if ip == "221.215.48.6":  # operator IP bypasses the limit
        return False
    now = datetime.now()
    traffic = traffic_collection.find_one({"ip": ip})  # look up this IP's records
    if not traffic:
        # first request from this IP: start a fresh record
        traffic_collection.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # Drop entries older than 24 hours BEFORE counting.  The original code
    # counted first and pruned afterwards, so stale requests kept inflating
    # the total and an IP could stay blocked forever.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 400:  # daily request cap for this endpoint
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection.update_one({"ip": ip}, {"$set": traffic})  # persist pruned + new entry
    return False
def check_traffic_poe(ip):
    """Rate-limit an IP for poe/claude: True when over 10 requests in 24h.

    Each allowed call records one request in ``traffic_collection_poe``.

    :param ip: the client IP address to check
    :return: True when the caller is over the limit, False otherwise
    """
    now = datetime.now()
    traffic = traffic_collection_poe.find_one({"ip": ip})  # look up this IP's records
    if not traffic:
        # first request from this IP: start a fresh record
        traffic_collection_poe.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # Prune entries older than 24 hours BEFORE counting — the original
    # counted first, so stale requests permanently inflated the total.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 10:  # daily request cap for poe/claude
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection_poe.update_one({"ip": ip}, {"$set": traffic})  # persist pruned + new entry
    return False
def check_traffic_4(ip):
    """Rate-limit an IP for GPT-4: True when over 100 requests in 24 hours.

    Each allowed call records one request in ``traffic_collection_4``.

    :param ip: the client IP address to check
    :return: True when the caller is over the limit, False otherwise
    """
    now = datetime.now()
    traffic = traffic_collection_4.find_one({"ip": ip})  # look up this IP's records
    if not traffic:
        # first request from this IP: start a fresh record
        traffic_collection_4.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # Prune entries older than 24 hours BEFORE counting — the original
    # counted first, so stale requests permanently inflated the total.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 100:  # daily request cap for GPT-4
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection_4.update_one({"ip": ip}, {"$set": traffic})  # persist pruned + new entry
    return False
# check if the ip is in China
def check_ip_country(ip):
    """Return True when ``ip`` is NOT located in China (or lookup fails).

    Shells out to the ``geoiplookup`` CLI tool and scans its output for
    the substring "China".

    :param ip: the IP address to look up
    :return: True for non-Chinese (or undeterminable) IPs, False otherwise
    """
    try:
        # list-form argv (shell=False) keeps the untrusted ip out of a shell;
        # subprocess.run also waits for and reaps the child, unlike the
        # original Popen without a context manager.
        result = subprocess.run(["geoiplookup", ip], stdout=subprocess.PIPE, check=False)
        geo_location = result.stdout.decode()
        return "China" not in geo_location
    except Exception:
        # best-effort: on any lookup failure, fail open (treat as non-China).
        # Narrowed from the original bare except, which also ate SystemExit.
        return True
def generate_country_error():
    """Yield the single geo-restriction error message (for streaming replies)."""
    yield "您的IP地址不在大陆境内，无法使用本服务"
def generate_traffic():
    """Yield the single daily-limit-reached message (for streaming replies)."""
    yield "每天的聊天次数有限哦"
def generate_error_key():
    """Yield the single GPT-4 quota / missing-key error message."""
    yield "每个IP每24小时只能使用20次4.0，您必须申请一个有效key才能继续免费使用本服务，具体方式请访问微信链接：https://mp.weixin.qq.com/s/1ZbctH6Iaa7LZkguW6ZnVA"
def generate_error_poe():
    """Yield the single claude quota / missing-key error message."""
    yield "每个IP每24小时只能使用10次claude+和claude-100k，您必须申请一个有效key才能继续免费使用本服务，具体方式请访问微信链接：https://mp.weixin.qq.com/s/1ZbctH6Iaa7LZkguW6ZnVA"
def generate_error():
    """Yield the single generic upstream-failure message (for streaming replies)."""
    yield "出了点小状况，大概率是开发者太穷租不起大流量，或者调用的服务有问题，可以再试一次吗？"
def generate_error_max():
    """Yield the single conversation-too-long message (for streaming replies)."""
    yield "我们的聊天内容太多了，您最好刷新界面再试一次，注意，刷新后之前的对话内容将会丢失。"

async def async_request_glm(prompt: str):
    """Start a streaming chatglm_130b completion via the zhipu SDK.

    :param prompt: prompt text forwarded to ``zhipusdk.model_api.sse_invoke``
    :return: the SDK's SSE response object (iterate its ``.events()``)
    """
    print(prompt)  # debug trace of the outgoing prompt
    sse_response = zhipusdk.model_api.sse_invoke(
        model="chatglm_130b",
        prompt=prompt,
        temperature=0.95,
        top_p=0.7,
        incremental=True,
    )
    print(sse_response)  # debug trace of the response object
    return sse_response

async def async_request(prompt: str, api_key: str):
    """POST a streaming gpt-3.5-turbo-16k chat request without blocking the loop.

    The blocking ``requests.post`` call runs in a worker thread so the
    asyncio event loop stays responsive.

    :param prompt: list of chat messages (sent as the "messages" payload —
        the ``str`` annotation is inherited and inaccurate)
    :param api_key: bearer token for the Authorization header
    :return: a streaming ``requests.Response``; on failure, a synthetic
        response with ``status_code == 500``
    """
    data = {
        "model": "gpt-3.5-turbo-16k",
        "messages": prompt,
        "temperature": 1,
        "stream": True,  # enable the streaming API
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    url = "https://api.openai.com/v1/chat/completions"
    loop = asyncio.get_running_loop()
    # timeout=3 only bounds connect/first-byte; the stream itself is unbounded
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=3)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # Narrowed from the original bare except, which also swallowed
        # asyncio.CancelledError and so could ignore task cancellation.
        response = requests.Response()
        response.status_code = 500
    return response
async def async_request_4(prompt: str):
    """POST a streaming gpt-4-0613 chat request without blocking the loop.

    :param prompt: list of chat messages (sent as the "messages" payload —
        the ``str`` annotation is inherited and inaccurate)
    :return: a streaming ``requests.Response``; on failure, a synthetic
        response with ``status_code == 500``
    """
    data = {
        "model": "gpt-4-0613",
        "messages": prompt,
        "temperature": 1,
        "stream": True,  # enable the streaming API
    }
    headers = {
        "Content-Type": "application/json",
        # NOTE(review): hard-coded OpenAI key committed to source — rotate it
        # and load from configuration/environment instead.
        "Authorization": "Bearer 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",
    }
    url = "https://api.openai.com/v1/chat/completions"
    loop = asyncio.get_running_loop()
    # timeout=3 only bounds connect/first-byte; the stream itself is unbounded
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=3)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # Narrowed from the original bare except, which also swallowed
        # asyncio.CancelledError; return an empty 500 response on failure.
        response = requests.Response()
        response.status_code = 500
    return response


def get_save_file_path(client_ip, message):
    """Append the last message to a new numbered log file for ``client_ip``.

    Creates ``log/<client_ip>/`` on demand, then writes ``str(message[-1])``
    to ``<n+1>.txt`` where ``n`` is the number in the name of the most
    recently modified existing file (``1.txt`` when the directory is empty).

    NOTE(review): the previous docstring claimed long messages were appended
    to the latest file instead of a new one; the code never did that — every
    call writes a fresh file.  Docstring corrected, behaviour unchanged.

    :param client_ip: client IP address, used as the log subdirectory name
    :param message: list of messages; only the last element is written
    :return: path of the file that was written
    """
    log_dir = os.path.join('log', client_ip)
    # exist_ok avoids the check-then-create race of the original code
    os.makedirs(log_dir, exist_ok=True)
    # pick the next file number based on the most recently modified file
    file_list = os.listdir(log_dir)
    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(log_dir, fn)))
    if not file_list:
        file_name = os.path.join(log_dir, '1.txt')
    else:
        next_number = int(file_list[-1].split('.')[0]) + 1
        file_name = os.path.join(log_dir, str(next_number) + '.txt')
    # append mode matches the original even though the file is always new
    with open(file_name, 'a') as f:
        f.write(str(message[-1]))
    return file_name

def generate_text_glm(response, save_file_path=None, save_flag=0):
    """Yield each SSE event's payload text from a zhipu streaming response.

    ``save_file_path`` and ``save_flag`` are accepted only for signature
    parity with the other ``generate_text*`` helpers; they are unused here.

    :param response: zhipu SDK response exposing an ``.events()`` iterable
    """
    for sse_event in response.events():
        yield sse_event.data
def generate_text(response, save_file_path=None, save_flag=0, user_key=None):
    """Yield text chunks from an OpenAI streaming (SSE) chat response.

    Parses each ``data: {...}`` record out of the raw stream, yields the
    delta content, optionally logs it to a file, and finally charges the
    user for the whole reply.

    :param response: streaming ``requests.Response`` from the chat API
    :param save_file_path: file to append each raw chunk to when ``save_flag``
    :param save_flag: when truthy, append chunks and a closing marker to the file
    :param user_key: when given, deduct the reply's token count from this user
    """
    all_text = ''
    for event in response.iter_content(chunk_size=None):
        chunk = event.decode('utf-8')
        # one network chunk may carry several "data: {...}" SSE records
        # (renamed from the original `str`, which shadowed the builtin)
        for payload in chunk.split('data: ')[1:]:
            try:
                record = json.loads(payload)
                if "choices" in record and record["choices"]:
                    text = record["choices"][0]["delta"].get("content")
                    if text:
                        if save_flag:
                            with open(save_file_path, 'a') as f:
                                f.write(text)
                        # undo escaped newlines/tabs; the original called the
                        # \n replace twice — the second pass was a no-op
                        text = text.replace('\\n', '\n').replace('\\t', '\t')
                        all_text += text
                        yield text
            except Exception:
                # non-JSON records (keep-alives, the final "[DONE]") are skipped;
                # narrowed from the original bare except
                pass
    if save_flag:
        with open(save_file_path, 'a') as f:
            f.write("\'}\n")
    if user_key:
        # bill the user for the assistant's full reply
        # (leftover debug prints removed)
        message = [{"role": "assistant", "content": all_text}]
        consume_token(user_key, num_tokens_from_message(message))
def generate_text_poe(client, message, mode="a2_2"):
    """Yield incremental reply text from a poe bot, then reset its chat.

    :param client: poe client exposing ``send_message`` / ``send_chat_break``
    :param message: the user message to send
    :param mode: bot identifier, defaults to "a2_2"
    """
    yield from (piece["text_new"] for piece in client.send_message(mode, message))
    # clear the bot's conversation context once the reply is fully consumed
    client.send_chat_break(mode)
    
def generate_text_json(response):
    """Collect the complete completion text from an OpenAI streaming response.

    Consumes the whole SSE stream and returns the concatenated delta
    content as one string (no incremental yielding).

    :param response: streaming ``requests.Response`` from the chat API
    :return: the full generated text
    """
    all_text = ''
    for event in response.iter_content(chunk_size=None):
        chunk = event.decode('utf-8')
        # one network chunk may carry several "data: {...}" SSE records
        # (loop variable renamed from `str`, which shadowed the builtin)
        for payload in chunk.split('data: ')[1:]:
            try:
                record = json.loads(payload)
                if "choices" in record and record["choices"]:
                    text = record["choices"][0]["delta"].get("content")
                    if text:
                        # undo escaped newlines/tabs; the original's second
                        # \n replace was a no-op and is dropped
                        all_text += text.replace('\\n', '\n').replace('\\t', '\t')
            except Exception:
                # non-JSON records (keep-alives, "[DONE]") are skipped;
                # narrowed from the original bare except
                pass
    # leftover debug print of the full text removed
    return all_text